# 一米正射辅助数据处理类
import time
import math
import numpy as np
from osgeo import gdal
from xml.etree.ElementTree import ElementTree
from scipy.optimize import leastsq


class OrthoAuxData:
    """Auxiliary-data helpers for 1 m ortho-rectification.

    Every method is a ``@staticmethod``; the class only serves as a namespace.
    ``meta_file_path`` arguments are paths to the XML product metadata file;
    each getter's docstring names the elements it reads.
    """

    def __init__(self):
        pass

    @staticmethod
    def _load_root(meta_file_path):
        """Parse the XML file at *meta_file_path* and return its root element."""
        tree = ElementTree()
        tree.parse(meta_file_path)
        return tree.getroot()

    @staticmethod
    def time_stamp(tm):
        """Convert a ``"YYYY-mm-dd HH:MM:SS[.fff]"`` string to a POSIX timestamp.

        The seconds field is rounded UP to a whole second (``math.ceil``)
        before parsing, so the result has whole-second resolution. The string
        is interpreted by ``time.mktime``, i.e. in the local time zone of the
        machine running the code.

        NOTE(review): sub-second precision is discarded here; confirm this is
        acceptable for orbit interpolation.
        """
        parts = tm.split(':')
        whole_sec = math.ceil(float(parts[2]))
        rounded = f"{parts[0]}:{parts[1]}:{whole_sec}"
        return float(time.mktime(time.strptime(rounded, "%Y-%m-%d %H:%M:%S")))

    @staticmethod
    def read_meta(meta_file_path):
        """Read the GPS state vectors listed under ``<GPS>``.

        Each child of ``<GPS>`` must contain ``xPosition``, ``yPosition``,
        ``zPosition``, ``xVelocity``, ``yVelocity``, ``zVelocity`` and
        ``TimeStamp``.

        Returns:
            ``(T, meta_data)``. ``T`` is the list of timestamps (see
            ``time_stamp``). ``meta_data`` is ``[Xs, Ys, Zs, Vsx, Vsy, Vsz]``,
            each a list of floats in the same order as ``T``.
        """
        root = OrthoAuxData._load_root(meta_file_path)
        tags = ('xPosition', 'yPosition', 'zPosition',
                'xVelocity', 'yVelocity', 'zVelocity')
        T = []
        meta_data = [[] for _ in tags]
        for child in root.find('GPS'):
            for series, tag in zip(meta_data, tags):
                series.append(float(child.find(tag).text))
            T.append(OrthoAuxData.time_stamp(child.find('TimeStamp').text))
        return T, meta_data

    @staticmethod
    def read_control_points(meta_file_path):
        """Read the scene centre and corner coordinates from ``<imageinfo>``.

        Returns:
            ``[longitudes, latitudes]``. The centre point comes first, followed
            by each child of ``<corner>`` in document order.
        """
        imageinfo = OrthoAuxData._load_root(meta_file_path).find('imageinfo')
        points = [imageinfo.find('center')] + list(imageinfo.find('corner'))
        longitudes = [float(pt.find('longitude').text) for pt in points]
        latitudes = [float(pt.find('latitude').text) for pt in points]
        return [longitudes, latitudes]

    @staticmethod
    def read_dem(dem_resampled_path, flag=1):
        """Sample a window of a DEM raster and return a position per pixel.

        The window starts at the raster centre pixel (``x_size // 2``,
        ``y_size // 2``) and spans ``x_size // 1000`` columns by
        ``y_size // 1000`` rows. Altitude is the pixel value rescaled by
        ``value / 255 * 1024``. NOTE(review): this assumes an 8-bit encoded
        DEM; confirm the scaling with the data producer.

        Args:
            dem_resampled_path: path of a GDAL-readable raster.
            flag: ``1`` converts (lon, lat, alt) to WGS-84 Cartesian XYZ via
                ``LLA2XYZ``. Any other value keeps (lon, lat, alt).

        Returns:
            A float array of shape ``(y_size // 1000, x_size // 1000, 3)``.
            It stays zero-filled if the raster has no bands.

        Raises:
            RuntimeError: if GDAL cannot open the file.
        """
        in_ds = gdal.Open(dem_resampled_path)
        if in_ds is None:
            raise RuntimeError(f"GDAL could not open DEM file: {dem_resampled_path}")
        try:
            gt = in_ds.GetGeoTransform()
            bands_num = in_ds.RasterCount
            x_size = in_ds.RasterXSize
            y_size = in_ds.RasterYSize
            c0, r0 = x_size // 2, y_size // 2
            x_size0, y_size0 = x_size // 1000, y_size // 1000
            # Geographic coordinate of the window's first pixel (the raster centre).
            origin_lon = gt[0] + c0 * gt[1]
            origin_lat = gt[3] + r0 * gt[5]
            pstn_arr = np.zeros([y_size0, x_size0, 3], dtype=float)
            if x_size0 == 0 or y_size0 == 0:
                # Raster narrower than 1000 px in some dimension: the window is
                # empty and a zero-size ReadAsArray would fail.
                return pstn_arr
            # NOTE(review): every band overwrites pstn_arr, so only the last band
            # contributes; confirm that multi-band DEMs are not expected.
            for band_index in range(1, bands_num + 1):
                data = in_ds.GetRasterBand(band_index).ReadAsArray(c0, r0, x_size0, y_size0)
                data = data / 255 * 1024
                for y in range(y_size0):
                    latitude = origin_lat + y * gt[5]
                    for x in range(x_size0):
                        longitude = origin_lon + x * gt[1]
                        altitude = data[y, x]
                        if flag == 1:
                            pstn = OrthoAuxData.LLA2XYZ(longitude, latitude, altitude)
                        else:
                            pstn = [longitude, latitude, altitude]
                        pstn_arr[y, x, :] = pstn
            return pstn_arr
        finally:
            # Dropping the reference closes the GDAL dataset.
            in_ds = None

    @staticmethod
    def orbit_fitting(time_array, meta_data):
        """Fit a cubic polynomial in time to each orbit series by least squares.

        Args:
            time_array: sample timestamps. The first one is the reference ``T0``.
            meta_data: sequence of series (e.g. ``[Xs, Ys, Zs, Vsx, Vsy, Vsz]``
                from ``read_meta``), each the same length as ``time_array``.

        Returns:
            ``(orbital_paras, T0)``. ``orbital_paras`` holds one coefficient
            array ``[w3, w2, w1, w0]`` per series, such that
            ``value ~= w3*t**3 + w2*t**2 + w1*t + w0`` with ``t = time - T0``.
        """
        T0 = time_array[0]
        # Times relative to T0, so absolute timestamps are never cubed.
        t = np.array([ts - T0 for ts in time_array])

        def error(p, x, y):
            w3, w2, w1, w0 = p
            return w3 * x**3 + w2 * x**2 + w1 * x + w0 - y

        orbital_paras = []
        for series in meta_data:
            fitted, status = leastsq(error, [1, 2, 3, 4], args=(t, np.array(series)))
            orbital_paras.append(fitted)
            # Debug output: fitted coefficients and the leastsq status flag.
            print(fitted, status)
        return orbital_paras, T0

    @staticmethod
    def get_PRF(meta_file_path):
        """Return ``sensor/waveParams/wave/prf`` as a float."""
        sensor = OrthoAuxData._load_root(meta_file_path).find('sensor')
        return float(sensor.find('waveParams').find('wave').find('prf').text)

    @staticmethod
    def get_delta_R(meta_file_path):
        """Return ``c * pulseWidth / (1e9 * 2)`` with ``c = 3e8``.

        ``pulseWidth`` is read from ``sensor/waveParams/wave/pulseWidth``.
        The 1e9 divisor suggests it is given in nanoseconds (TODO confirm).
        """
        sensor = OrthoAuxData._load_root(meta_file_path).find('sensor')
        pulse_width = float(sensor.find('waveParams').find('wave').find('pulseWidth').text)
        c = 300000000
        return c * pulse_width / (1000000000 * 2)

    @staticmethod
    def get_doppler_rate_coef(meta_file_path):
        """Read the Doppler-rate polynomial from ``<processinfo>``.

        Returns:
            ``(t0, coefs)``. ``t0`` is ``DopplerParametersReferenceTime``.
            ``coefs`` is a ``(5, 1)`` array ``[r0, r1, r2, r3, r4]`` taken from
            ``DopplerRateValuesCoefficients``.
        """
        processinfo = OrthoAuxData._load_root(meta_file_path).find('processinfo')
        doppler = processinfo.find('DopplerRateValuesCoefficients')
        t0 = float(processinfo.find('DopplerParametersReferenceTime').text)
        coefs = [float(doppler.find(f'r{k}').text) for k in range(5)]
        return t0, np.array(coefs).reshape(5, 1)

    @staticmethod
    def get_doppler_center_coef(meta_file_path):
        """Return ``(d0, d1, d2)`` from ``processinfo/DopplerCentroidCoefficients``."""
        processinfo = OrthoAuxData._load_root(meta_file_path).find('processinfo')
        doppler = processinfo.find('DopplerCentroidCoefficients')
        return tuple(float(doppler.find(f'd{k}').text) for k in range(3))

    @staticmethod
    def get_lamda(meta_file_path):
        """Return ``sensor/lamda`` (wavelength) as a float."""
        sensor = OrthoAuxData._load_root(meta_file_path).find('sensor')
        return float(sensor.find('lamda').text)

    @staticmethod
    def get_t0(meta_file_path):
        """Return the timestamp of ``imageinfo/imagingTime/start``."""
        imageinfo = OrthoAuxData._load_root(meta_file_path).find('imageinfo')
        return OrthoAuxData.time_stamp(imageinfo.find('imagingTime').find('start').text)

    @staticmethod
    def get_start_and_end_time(meta_file_path):
        """Return ``(start, end)`` timestamps from ``imageinfo/imagingTime``."""
        imaging_time = OrthoAuxData._load_root(meta_file_path).find('imageinfo').find('imagingTime')
        start = OrthoAuxData.time_stamp(imaging_time.find('start').text)
        end = OrthoAuxData.time_stamp(imaging_time.find('end').text)
        return start, end

    @staticmethod
    def get_width_and_height(meta_file_path):
        """Return ``(width, height)`` as ints from ``<imageinfo>``."""
        imageinfo = OrthoAuxData._load_root(meta_file_path).find('imageinfo')
        return int(imageinfo.find('width').text), int(imageinfo.find('height').text)

    @staticmethod
    def get_R0(meta_file_path):
        """Return ``imageinfo/nearRange`` as a float."""
        imageinfo = OrthoAuxData._load_root(meta_file_path).find('imageinfo')
        return float(imageinfo.find('nearRange').text)

    @staticmethod
    def get_h():
        """Return the hard-coded value 6.6.

        NOTE(review): the meaning and unit of this constant are not documented
        here; confirm with its callers.
        """
        return 6.6

    @staticmethod
    def LLA2XYZ(longitude, latitude, altitude):
        """Convert WGS-84 geodetic coordinates to Earth-centred Cartesian XYZ.

        Args:
            longitude: degrees.
            latitude: degrees.
            altitude: ellipsoidal height in metres.

        Returns:
            ``[X, Y, Z]`` in metres.
        """
        lat = math.radians(latitude)
        lon = math.radians(longitude)
        a = 6378137.0              # WGS-84 semi-major (equatorial) axis, m
        f = 1.0 / 298.257223563    # WGS-84 flattening
        one_minus_e2 = (1.0 - f) ** 2
        cos_lat, sin_lat = math.cos(lat), math.sin(lat)
        # C = N / a, where N is the prime-vertical radius of curvature.
        C = 1.0 / math.sqrt(cos_lat * cos_lat + one_minus_e2 * sin_lat * sin_lat)
        S = one_minus_e2 * C
        X = (a * C + altitude) * cos_lat * math.cos(lon)
        Y = (a * C + altitude) * cos_lat * math.sin(lon)
        Z = (a * S + altitude) * sin_lat
        return [X, Y, Z]

    @staticmethod
    def XYZ2LLA(X, Y, Z):
        """Convert WGS-84 Earth-centred Cartesian XYZ to geodetic coordinates.

        Latitude uses Bowring's closed-form formula. All operations are NumPy
        ufuncs, so scalars or equally shaped arrays are accepted.

        Returns:
            ``[longitude_deg, latitude_deg, altitude_m]``.
        """
        a = 6378137.0        # semi-major axis, m
        b = 6356752.314245   # semi-minor axis, m
        ea2 = (a ** 2 - b ** 2) / a ** 2   # first eccentricity squared
        eb2 = (a ** 2 - b ** 2) / b ** 2   # second eccentricity squared
        p = np.sqrt(X ** 2 + Y ** 2)
        theta = np.arctan2(Z * a, p * b)

        longitude = np.arctan2(Y, X)
        latitude = np.arctan2(Z + eb2 * b * np.sin(theta) ** 3,
                              p - ea2 * a * np.cos(theta) ** 3)
        sin_lat = np.sin(latitude)
        # h = p*cos(lat) + Z*sin(lat) - a*sqrt(1 - e^2 sin^2(lat)). Unlike
        # p / cos(lat) - N, this stays finite at the poles where cos(lat) -> 0.
        altitude = p * np.cos(latitude) + Z * sin_lat - a * np.sqrt(1 - ea2 * sin_lat ** 2)
        return [np.degrees(longitude), np.degrees(latitude), altitude]

    @staticmethod
    def world2Pixel(geoMatrix, x, y):
        """Return the ``(pixel, line)`` index containing geographic point (x, y).

        ``geoMatrix`` is a GDAL geotransform (``GetGeoTransform()``). The
        rotation terms ``geoMatrix[2]`` and ``geoMatrix[4]`` are ignored, and
        rows are counted downward from the top edge using
        ``abs(geoMatrix[5])``, i.e. a north-up raster is assumed.

        Indices are floored, so points left of or above the origin yield
        negative indices instead of being folded onto column/row 0.
        """
        pixel = math.floor((x - geoMatrix[0]) / geoMatrix[1])
        line = math.floor((geoMatrix[3] - y) / abs(geoMatrix[5]))
        return pixel, line

    @staticmethod
    def sar_intensity_synthesis(in_sar_tif, out_sar_tif):
        """Combine bands 1 and 2 of *in_sar_tif* into a one-band GeoTIFF.

        Each input band is converted with ``10 ** (value / 10)``, i.e. it is
        treated as dB (NOTE(review): confirm the input bands are really in dB).
        The output pixel is ``sqrt(band1**2 + band2**2)``. Projection and
        geotransform are copied from the input.

        Raises:
            RuntimeError: if the input cannot be opened or the output cannot
                be created.
            ValueError: if the input has fewer than two bands.
        """
        in_ds = gdal.Open(in_sar_tif)
        if in_ds is None:
            raise RuntimeError(f"GDAL could not open SAR image: {in_sar_tif}")
        out_ds = None
        try:
            if in_ds.RasterCount < 2:
                raise ValueError(f"{in_sar_tif} has {in_ds.RasterCount} band(s); 2 required")
            rows = in_ds.RasterYSize
            columns = in_ds.RasterXSize

            linear = []
            for band_index in (1, 2):
                data = in_ds.GetRasterBand(band_index).ReadAsArray(0, 0, columns, rows)
                linear.append(np.power(10, data / 10))
            out_data = np.sqrt(linear[0] ** 2 + linear[1] ** 2)

            # Float32 output: GDAL's default Create type (Byte) would clamp the
            # float intensities to 0..255.
            out_ds = gdal.GetDriverByName('GTiff').Create(
                out_sar_tif, columns, rows, 1, gdal.GDT_Float32)
            if out_ds is None:
                raise RuntimeError(f"GDAL could not create output: {out_sar_tif}")
            out_ds.SetProjection(in_ds.GetProjection())
            out_ds.SetGeoTransform(in_ds.GetGeoTransform())
            out_ds.GetRasterBand(1).WriteArray(out_data)
            out_ds.FlushCache()
        finally:
            # Dropping the references closes (and flushes) the GDAL datasets.
            out_ds = None
            in_ds = None